import os
import numpy as np
import torch
import dgl
import torch.optim as optim
from model import *
from utils import *
import json


import warnings
warnings.filterwarnings('ignore')


def _class_accuracy(preds, labels, cls):
    """Return the fraction of samples whose true label is ``cls`` that were predicted as ``cls``.

    Args:
        preds: 1-D numpy array of predicted class ids.
        labels: 1-D numpy array of true class ids, aligned with ``preds``.
        cls: the class id to score (1 -> "ACC1", 0 -> "ACC0").

    Returns:
        Per-class accuracy as a float, or 0.0 when no sample has label ``cls``
        (the convention the script has always used for an empty class).
    """
    members = labels == cls
    if not members.any():
        return 0.0
    return float(np.mean(preds[members] == cls))


def _one_shot_train_mask(train_mask_raw, labels):
    """Build a training mask that keeps one random positive and one random negative node.

    Candidates are the training nodes labelled 1 (positive) or 0 (negative);
    any other label is never selected. Draws come from torch's global RNG,
    positive first and then negative, so the caller controls reproducibility
    by seeding torch beforehand.

    Args:
        train_mask_raw: 1-D node mask of any dtype; nonzero means "in the training set".
        labels: 1-D tensor of node labels aligned with ``train_mask_raw``.

    Returns:
        A tensor with the same shape and dtype as ``train_mask_raw`` and at most
        two entries set to 1. When a class has no candidate, a message is printed
        and that class is skipped.
    """
    train_mask = train_mask_raw.bool()
    train_nodes = train_mask.nonzero().flatten()  # global ids of training nodes
    train_labels = labels[train_mask]
    new_mask = torch.zeros_like(train_mask_raw)
    # Messages: "no positive sample available" / "no negative sample available".
    for cls, missing_msg in ((1, "没有可用的正样本"), (0, "没有可用的负样本")):
        candidates = train_nodes[train_labels == cls]
        if len(candidates) == 0:
            print(missing_msg)
            continue
        pick = torch.randint(len(candidates), (1,)).item()
        new_mask[int(candidates[pick])] = 1
    return new_mask


def _load_checkpoint(path, device):
    """Load a whole model pickled with ``torch.save(model, path)`` onto ``device``.

    PyTorch >= 2.6 defaults to ``torch.load(weights_only=True)``, which refuses
    to unpickle full ``nn.Module`` objects. ``weights_only=False`` is therefore
    passed explicitly whenever the installed torch accepts that argument.
    Only use this on checkpoints this script wrote itself: unpickling can
    execute arbitrary code.
    """
    import inspect

    load_kwargs = {'map_location': device}
    if 'weights_only' in inspect.signature(torch.load).parameters:
        load_kwargs['weights_only'] = False
    return torch.load(path, **load_kwargs)


def _evaluate_split(model, graph, eval_mask, result_path=None):
    """Score ``model`` on the nodes selected by ``eval_mask``.

    Logits are moved to CPU before being handed to the project's ``evaluate``
    helper, so validation and test feed it the same kind of tensor.

    Args:
        model: the detector; called as ``model(graph)`` to get per-node logits.
        graph: the DGL graph, with node labels in ``graph.ndata['label']``.
        eval_mask: boolean node mask of the nodes to score.
        result_path: forwarded as the third argument of ``evaluate`` when given.

    Returns:
        Tuple ``(f1_macro, auc, gmean, recall, acc1, acc0)``. The first four
        values come from ``evaluate``; acc1/acc0 are the per-class accuracies
        for labels 1 and 0.
    """
    model.eval()
    with torch.no_grad():
        logits = model(graph)[eval_mask].cpu()
        labels = graph.ndata['label'][eval_mask].cpu().numpy()
        preds = logits.argmax(1).numpy()
        if result_path is None:
            f1_macro, auc, gmean, recall = evaluate(labels, logits)
        else:
            f1_macro, auc, gmean, recall = evaluate(labels, logits, result_path)
    acc1 = _class_accuracy(preds, labels, 1)
    acc0 = _class_accuracy(preds, labels, 0)
    return f1_macro, auc, gmean, recall, acc1, acc0


if __name__ == '__main__':
    args = parse_args()
    # Global seeding through the project's setup_seed helper.
    # NOTE(review): earlier comments claimed this seed was 72, but the value is 2;
    # the one-shot sample draw below reseeds torch with 72. Both values are kept
    # so results stay reproducible against earlier runs. Confirm which is intended.
    setup_seed(2)
    device = torch.device(args.cuda)
    args.device = device
    dataset_path = args.data_path + args.dataset + '.dgl'
    model_path = args.result_path + args.dataset + '_model.pt'
    log_path = args.result_path + args.dataset + '_log.json'
    results = {'F1-macro': [], 'AUC': [], 'G-Mean': [], 'recall': [], 'ACC1': [], 'ACC0': []}
    os.makedirs(args.result_path, exist_ok=True)

    # Load the graph and normalize node features with the project's `normalize` helper.
    dataset = dgl.load_graphs(dataset_path)[0][0]
    features = normalize(dataset.ndata['feature'].numpy())
    dataset.ndata['feature'] = torch.from_numpy(features).float()

    # One-shot setting: shrink the training set to one positive and one negative node.
    # Reseeding here fixes which nodes are drawn; the draws also advance the global
    # RNG that model initialisation below consumes.
    torch.manual_seed(72)
    dataset.ndata['train_mask'] = _one_shot_train_mask(
        dataset.ndata['train_mask'], dataset.ndata['label'])

    dataset = dataset.to(device)

    # Train the model.
    print('Start training model...')
    model = H2FDetector(args, dataset).to(device)
    optimizer = optim.Adam(params=model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    early_stop = EarlyStop(args.early_stop)

    # Validation nodes with a real 0/1 label (label 2 marks unlabelled samples).
    # The mask does not change across epochs, so it is built once.
    valid_eval_mask = dataset.ndata['valid_mask'].bool() & (dataset.ndata['label'] != 2)

    valid_logs = []
    checkpoint_saved = False

    for e in range(args.epoch):
        model.train()
        loss = model.loss(dataset)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loss_value = loss.item()

        # Validate.
        f1_macro, auc, gmean, recall, acc1, acc0 = _evaluate_split(model, dataset, valid_eval_mask)

        # Cast to float so json.dump can serialise numpy/torch scalar types.
        valid_logs.append({
            'epoch': e,
            'loss': loss_value,
            'F1-macro': float(f1_macro),
            'AUC': float(auc),
            'G-Mean': float(gmean),
            'recall': float(recall),
            'ACC1': acc1,
            'ACC0': acc0,
        })

        if args.log:
            print(f'{e}: Best Epoch:{early_stop.best_epoch}, Best valid AUC:{early_stop.best_eval}, Loss:{loss_value}, Current valid: Recall:{recall}, F1_macro:{f1_macro}, G-Mean:{gmean}, AUC:{auc}, ACC1:{acc1}, ACC0:{acc0}')
        do_store, do_stop = early_stop.step(auc, e)
        if do_store:
            torch.save(model, model_path)
            checkpoint_saved = True
        if do_stop:
            break
    print('End training')

    if not checkpoint_saved:
        # Without this, testing would crash on a missing file or silently load a
        # stale checkpoint left over from an earlier run.
        print('No checkpoint was stored during training; saving the final model instead.')
        torch.save(model, model_path)

    # Test the best stored model.
    print('Test model...')
    model = _load_checkpoint(model_path, device)
    test_eval_mask = dataset.ndata['test_mask'].bool() & (dataset.ndata['label'] != 2)
    test_result_path = args.result_path + args.dataset
    f1_macro, auc, gmean, recall, test_acc1, test_acc0 = _evaluate_split(
        model, dataset, test_eval_mask, test_result_path)

    test_metrics = {
        'F1-macro': f1_macro,
        'AUC': auc,
        'G-Mean': gmean,
        'recall': recall,
        'ACC1': test_acc1,
        'ACC0': test_acc0,
    }
    for key, value in test_metrics.items():
        results[key].append(float(value))

    print(f'Test: F1-macro:{f1_macro}, AUC:{auc}, G-Mean:{gmean}, Recall:{recall}, ACC1:{test_acc1}, ACC0:{test_acc0}')

    # Persist the per-epoch validation log together with the test results.
    log_data = {
        'valid_logs': valid_logs,
        'test_results': results,
    }

    with open(log_path, 'w', encoding='utf-8') as f:
        json.dump(log_data, f, indent=4)
    

